# Table 1 in "Evaluating (weighted) dynamic treatment effects by double machine learning"
# by Hugo Bodory, Martin Huber, and Lukáš Lafférs

# Setup ----
# Start from an empty workspace so objects from an earlier session cannot leak
# into the simulation.
rm(list = ls())
# All packages are attached on a best-effort basis (require() returns FALSE
# instead of erroring). The two this script actually calls -- mvtnorm
# (rmvnorm) and DescTools (PseudoR2) -- are mandatory: if either is missing
# the script stops here instead of failing later with "could not find function".
packages <- c("SuperLearner", "e1071", "glmnet", "ranger", "xgboost", "tictoc",
              "parallel", "writexl", "ggplot2", "dplyr", "mvtnorm", "MASS",
              "DescTools")
loaded <- vapply(packages, require, logical(1), character.only = TRUE)
required <- c("mvtnorm", "DescTools")
if (!all(loaded[required])) {
  stop("required package(s) not available: ",
       paste(required[!loaded[required]], collapse = ", "), call. = FALSE)
}

# 18 (pseudo)R2 accumulators: outcome (d1, d2, y2) x sample-size index
# (n1, n2) x covariate-count index (p1, p2, p3). All start out empty (NULL).
r2_d1_n1_p1 <- r2_d1_n1_p2 <- r2_d1_n1_p3 <-
  r2_d1_n2_p1 <- r2_d1_n2_p2 <- r2_d1_n2_p3 <-
  r2_d2_n1_p1 <- r2_d2_n1_p2 <- r2_d2_n1_p3 <-
  r2_d2_n2_p1 <- r2_d2_n2_p2 <- r2_d2_n2_p3 <-
  r2_y2_n1_p1 <- r2_y2_n1_p2 <- r2_y2_n1_p3 <-
  r2_y2_n2_p1 <- r2_y2_n2_p2 <- r2_y2_n2_p3 <- NULL

# Design grid ----
n_all <- c(2500, 10000)    # sample sizes
p0_all <- c(50, 100, 500)  # numbers of covariates

# Effect sizes and DGP switches ----
effd1 <- 1     # effect of the first treatment
effd2 <- 1     # effect of the second treatment
het <- 0       # effect heterogeneity indicator; zero switches it off
decrease <- 1  # non-zero: coefficients of decreasing importance
xcorr <- 1     # non-zero: non-zero covariance between regressors

# Simulation ----
# Loops over sample size, number of covariates, and simulation replications,
# creating the DGP and computing three (pseudo)R2 statistics per replication:
#   * Nagelkerke pseudo-R2 of a probit of d1 on x0,
#   * Nagelkerke pseudo-R2 of a probit of d2 on (d1, x0, x1),
#   * OLS R2 of y2 on (x0, x1).
# The values for cell (n_all[i_n], p0_all[i_p0]) are stored, one element per
# replication, in r2_d1_n<i_n>_p<i_p0>, r2_d2_n<i_n>_p<i_p0> and
# r2_y2_n<i_n>_p<i_p0>.
#
# Design switches:
#   xcorr    non-zero: covariates ~ N(0, Sigma) with Sigma[i, j] = 0.5^|i-j|;
#            zero: Sigma is the identity (uncorrelated regressors).
#   decrease non-zero: coefficients 0.4 / j^4 (decreasing importance);
#            zero: no coefficient design is defined, so the script stops.
if (decrease == 0) {
  stop("decrease == 0 is not supported: no alternative coefficient design ",
       "is defined", call. = FALSE)
}
# Base of the covariance rho^|i-j|; since 0^0 == 1 in R, rho = 0 gives diag(p).
rho <- if (xcorr != 0) 0.5 else 0

for (i_n in seq_along(n_all)) {
  n <- n_all[i_n]
  # Number of replications per sample size (fewer for the larger sample).
  numsim <- if (n == 2500) {
    1000
  } else if (n == 10000) {
    250
  } else {
    stop("no number of replications defined for n = ", n, call. = FALSE)
  }
  for (i_p0 in seq_along(p0_all)) {
    p0 <- p0_all[i_p0]
    s0 <- p0; p1 <- p0; s1 <- p0
    # Covariances and coefficients are identical in every replication and draw
    # no random numbers, so they are built once per cell; each replication's
    # random stream (set.seed(i)) is therefore unchanged.
    sigma0 <- rho^abs(outer(seq_len(p0), seq_len(p0), "-"))  # covariance of x0
    sigma1 <- rho^abs(outer(seq_len(p1), seq_len(p1), "-"))  # covariance of x1
    beta0 <- rep(0, p0)
    beta0[seq_len(s0)] <- 0.4 / seq_len(s0)^4
    beta1 <- rep(0, p1)
    beta1[seq_len(s1)] <- 0.4 / seq_len(s1)^4

    r2_d1 <- numeric(numsim)
    r2_d2 <- numeric(numsim)
    r2_y2 <- numeric(numsim)
    for (i in seq_len(numsim)) {
      set.seed(i)
      # Draw order (x0, d1 noise, x1, d2 noise, u) matches the original DGP.
      x0 <- rmvnorm(n, rep(0, p0), sigma0)  # covariate matrix at baseline
      index0 <- x0 %*% beta0
      d1 <- (index0 + rnorm(n) > 0) * 1  # first treatment in period 1
      x1 <- rmvnorm(n, rep(0, p1), sigma1)  # covariate matrix of period 1
      index1 <- x1 %*% beta1
      # second treatment in period 2
      d2 <- (index0 + index1 + 0.3 * d1 + rnorm(n) > 0) * 1
      u <- rnorm(n)
      # outcome equation in period 2
      y2 <- index0 + index1 + effd1 * d1 + het * x0[, 1] * d1 + effd2 * d2 + u

      px0 <- glm(d1 ~ x0, family = binomial(probit))
      pd1x0x1 <- glm(d2 ~ cbind(d1, x0, x1), family = binomial(probit))
      r2_d1[i] <- PseudoR2(px0, "Nagelkerke")
      r2_d2[i] <- PseudoR2(pd1x0x1, "Nagelkerke")
      r2_y2[i] <- summary(lm(y2 ~ cbind(x0, x1)))$r.squared
    }
    # Publish the cell's results under the names the results section reads.
    cell <- sprintf("_n%d_p%d", i_n, i_p0)
    assign(paste0("r2_d1", cell), r2_d1)
    assign(paste0("r2_d2", cell), r2_d2)
    assign(paste0("r2_y2", cell), r2_y2)
  }
}

# Results ----
# Mean (pseudo)R2 over the replications of each design cell (n1/n2 index
# n_all, p1/p2/p3 index p0_all), grouped by equation.
# In the d2 cell with n1 and p3, only replications with |R2| < 1 enter the
# mean. NOTE(review): presumably this excludes degenerate probit fits with
# 1001 regressors at n = 2500 -- confirm. NA/NaN values are NOT removed by
# this subset and would still make that mean NA.

# first treatment equation (d1 on x0)
average_r2_d1_n1_p1 <- mean(r2_d1_n1_p1)
average_r2_d1_n1_p2 <- mean(r2_d1_n1_p2)
average_r2_d1_n1_p3 <- mean(r2_d1_n1_p3)
average_r2_d1_n2_p1 <- mean(r2_d1_n2_p1)
average_r2_d1_n2_p2 <- mean(r2_d1_n2_p2)
average_r2_d1_n2_p3 <- mean(r2_d1_n2_p3)

# second treatment equation (d2 on d1, x0, x1)
average_r2_d2_n1_p1 <- mean(r2_d2_n1_p1)
average_r2_d2_n1_p2 <- mean(r2_d2_n1_p2)
average_r2_d2_n1_p3 <- mean(r2_d2_n1_p3[abs(r2_d2_n1_p3) < 1])
average_r2_d2_n2_p1 <- mean(r2_d2_n2_p1)
average_r2_d2_n2_p2 <- mean(r2_d2_n2_p2)
average_r2_d2_n2_p3 <- mean(r2_d2_n2_p3)

# outcome equation (y2 on x0, x1)
average_r2_y2_n1_p1 <- mean(r2_y2_n1_p1)
average_r2_y2_n1_p2 <- mean(r2_y2_n1_p2)
average_r2_y2_n1_p3 <- mean(r2_y2_n1_p3)
average_r2_y2_n2_p1 <- mean(r2_y2_n2_p1)
average_r2_y2_n2_p2 <- mean(r2_y2_n2_p2)
average_r2_y2_n2_p3 <- mean(r2_y2_n2_p3)
